#!/bin/env python
#-*- coding:utf-8 -*-
author="将军"

import sys
import time
import numpy
import random as ran
import pandas
import operator
from numpy import *

class LoadData(object):
    #'加载数据'
    def irisReadData(self,filename):
        #'鸢尾花数据加载'
        print('正在加载数据...')
        data=pandas.read_csv(filename,encoding="utf-8")
        print('数据加载完成!')
        return data

    def cutData(self,data,num):
        #'随机切割训练数据'
        #'data="原数据",num="随机次数"'
        print('正在切割数据...')
        all_data=[]
            #'全部训练数据'
        for i in data.values:
            all_data.append(list(i))
        test_data=[]
            #'测试数据'
        for i in range(num):
            item=ran.choice(all_data)
            test_data.append(item)
            all_data.remove(item)
        train_data=pandas.DataFrame(all_data)
            #'重新转换剩余训练数据为数据框形式'
        print('数据切割完成!')
        return test_data,train_data

    def trainData(self,traindata):
        #'创建训练数据'
        #'traindata="切割后的训练数据"'
        print('正在创建训练数据...')
        labels=[]
            #'所有类别'
        for i in traindata.values:
            labels.append(i[4])
        traits=zeros((len(labels),4))
            #'所有类别对应的单个特征的集合'
        for i in range(0,len(labels)):
            traits[i,:]=traindata.values[i][:4]
        print('训练数据创建完成...')
        return labels,traits


class Knn(object):
    #'knn算法对鸢尾花进行预测分类'
    def knn(self,k,testdata,traindata,labels):
        #'knn算法,k="前k个训练数据,"testdata="测试数据",traindata="训练数据",labels="训练数据的类别"'
        train_data_size=traindata.shape[0]
            #'训练数据个数'
        dif=tile(testdata,(train_data_size,1))-traindata
            #'测试数据和训练数据的差值'
        sqdif=dif**2
            #'测试数据和训练数据差值的平方'
        sumsqdif=sqdif.sum(axis=1)
            #'测试数据和训练数据差值的平方和'
        distance=sumsqdif**0.5
            #'测试数据和训练数据的欧氏距离(即平方根)'
        sortdistance=distance.argsort()
            #'测试数据到各训练数据的距离排序后的结果'
        count={}
            #'测试数据距离训练数据的统计结果'
        for i in range(0,k):
            vote=labels[sortdistance[i]]
                #'测试数据与训练数据最近的那个训练数据的类别'
            count[vote]=count.get(vote,0)+1
                #'测试数据与训练数据匹配的当前训练数据的类别,相同的类别累加'
        sortcount=sorted(count.items(),key=operator.itemgetter(1),reverse=True)
            #'测试数据所匹配训练数据中按距离最近排序后的,排序按照统计的数量'
        return sortcount[0][0]


class Bayes(object):
    #'bayes(朴素贝叶斯)算法对鸢尾花进行分类'
    def __init__(self):
        #'初始化数据'
        self.length=-1
            #'数据长度'
        self.label_count=dict()
            #'各类别概率集合'
        self.vector_count=dict()
            #'类别对应特征集合'

    def fit(self,traindata,labels):
        #'使用训练数据进行训练'
        if len(traindata)!=len(labels):
            raise ValueError('训练数据数量与类别数量不一致')
        print('正在使用训练数据训练,请等待...')
        self.length=len(traindata[0])
            #'单个数据特征长度'
        labels_num=len(labels)
        nor_labels=set(labels)
        for item in nor_labels:
            self.label_count[item]=labels.count(item)/float(labels_num)
                #'当前类别在总类别中的占比=当前类别在总类别中的数量/总类别数量'
        corss_data=zip(traindata,labels)
            #'使用zip()函数是训练数据特征与其对应类别一一对应'
        for vector,label in corss_data:
            if label not in self.vector_count:
                self.vector_count[label]=[]
            self.vector_count[label].append(vector)
        print('使用训练数据训练完成!')
        return self

    def bayesTest(self,testdata,labels):
        #'对测试数据进行预测分类'
        #'testdata="单个数据的特征集合",labels="训练数据类别集合"'
        if self.length==-1:
            raise ValueError('没有进行训练无法进行预测,请先训练')
        #print('正在预测测试数据类别,请等待...')
        categorys=dict()
            #'类别的概率集合'
        for t in labels:
            #'计算每个类别的概率对应与类别概率集合'
            probability=float(1)
                #'测试数据所有特征的总概率'
            label_probability=self.label_count[t]
                #'当前类别的概率'
            all_vector=self.vector_count[t]
                #'当前类别的所有特征'
            num=len(all_vector)
                #'当前类别的特征向量总数量'
            all_vector=numpy.array(all_vector).T.tolist()
                #'将当前类别所有特征向量转换为数组,并将行列转制'
            for index in range(0,len(testdata[:4])):
                #'index为单个测试数据中的对应的单个特征'
                vector=all_vector[index]
                    #'与当前测试数据特征向量对应的该类别的所有特征向量'
                probability*=float(vector.count(testdata[:4][index])/num)
            categorys[t]=probability*label_probability
                #'当前类别的概率=当前类别特征的总概率*当前类别的概率'
        result=sorted(categorys,key=lambda x:categorys[x],reverse=True)[0]
        #print('测试数据预测完成!')
        return result

def knnMain(k,data,tst_data,trn_data,labels,traits):
    #'主程序'
    x=0
    for t in tst_data:
        t_data=numpy.array(t[:4])
        rst=k.knn(10,t_data,traits,labels)
        print('测试数据:%s-->预测结果为:%s' %(t,rst))
        if t[4]==rst:
            x+=1
        else:
            print('该条预测结果为错误,请查看!')
    return str(x/len(tst_data)*100)+"%"

def bayesMain(b,data,tst_data,trn_data,labels,traits):
    b.fit(traits,labels)
    x=0
    for t in tst_data:
        rst=b.bayesTest(t,labels)
        print('测试数据:%s-->预测结果为:%s' %(t,rst))
        if t[4]==rst:
            x+=1
        else:
            print('该条预测结果为错误,请查看!')
    return str(x/len(tst_data)*100)+"%"


if __name__=="__main__":
    k=Knn()
    b=Bayes()
    l=LoadData()
    filename='iris.csv'
    data=l.irisReadData(filename)
    print('训练数据总数量:',len(data))
    tst_data,trn_data=l.cutData(data,30)
    print('切割后的测试数据数量:%s,训练数据数量:%s' %(len(tst_data),len(trn_data)))
    labels,traits=l.trainData(trn_data)
    print('创建的训练数据类别数量:%s,训练数据特征数量:%s' %(len(labels),len(traits)))
    k_time_start=time.time()
    k_precision_rate=knnMain(k,data,tst_data,trn_data,labels,traits)
    k_time_end=time.time()
    b_time_start=time.time()
    b_precision_rate=bayesMain(b,data,tst_data,trn_data,labels,traits)
    b_time_end=time.time()
    print('knn算法准确率:%s' %(k_precision_rate))
    print('bayes算法准确率:%s' %(b_precision_rate))
    print('knn算法耗时:%s(ms)' %(int((k_time_end-k_time_start)*1000)))
    print('bayes算法耗时:%s(ms)' %(int((b_time_end-b_time_start)*1000)))
    print('通过对比明显knn算法比bayes算法耗时更少,准确率更高,所以选择使用Knn算法进行预测分类')
